import torch
from run_Ours_genIter import get_MMD_values_uneven
from utils import set_deterministic, save_results
from data_utils import assign_data

import torch
from tqdm import tqdm
import argparse
from os.path import join as oj

# Label of the method evaluated by this script. It is printed in the banner,
# passed to save_results() as `baseline` for the main results, and used as the
# prefix of the per-sample-size result names (f'{baseline}_sample_complexity_{pct}').
baseline = 'Ours'

class options:
    """Plain namespace of run-wide settings, accessed as class attributes.

    The class is never instantiated in this script; values are read (and, for
    some datasets, overwritten) directly on the class object.
    """

    # Evaluated once, when the class body executes at import time.
    cuda = torch.cuda.is_available()

    # `batch_size` is forwarded to every get_MMD_values_uneven() call below;
    # `mmd_batch_size` is lowered to 256 for MNIST / CIFAR10 by the main script.
    batch_size, mmd_batch_size = 256, 1024

    # `steps` is raised to 20000 for CIFAR10 by the main script.
    steps = 10_000

    # Not read anywhere in this script.
    # NOTE(review): presumably network/image hyper-parameters - confirm usage.
    image_size, n_filters = 32, 100

if __name__ == '__main__':
    # Sample-complexity experiment driver.
    #
    # For each trial: draw the vendors' datasets, compute the valuation from
    # get_MMD_values_uneven() against (a) the validation set V_X and (b) the
    # union of all vendors' data, both with and without `squared=True`, first
    # with no sample-size limit and then once per sub-sampling percentage.
    # Everything is finally written out through save_results().

    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(description='Process which dataset to run')
    # NOTE: -N and -m are required, so argparse never falls back to `default`.
    parser.add_argument('-N', '--N', help='Number of data vendors.', type=int, required=True, default=5)
    parser.add_argument('-m', '--size', help='Size of sample datasets.', type=int, required=True, default=1500)
    parser.add_argument('-P', '--dataset', help='Pick the dataset to run.', type=str, required=True)
    parser.add_argument('-Q', '--Q_dataset', help='Pick the Q dataset.', type=str, required=False, choices=['normal', 'EMNIST', 'FaMNIST', 'CIFAR100' , 'CreditCard', 'UGR16'])
    parser.add_argument('-n_t', '--n_trials', help='Number of trials.', type=int, default=5)
    parser.add_argument('-nh', '--not_huber', help='Not with huber, meaning with other types of specified heterogeneity.', action='store_true')
    parser.add_argument('-het', '--heterogeneity', help='Type of heterogeneity.', type=str, default='normal', choices=['normal', 'label', 'classimbalance', 'classimbalance_inter'])
    # -kde and -gmm write to the same destination. When neither flag is given,
    # argparse applies the default of the first-registered action (store_false),
    # so `gmm` is True unless -kde is passed.
    parser.add_argument('-kde', dest='gmm', help='Whether to use KDE for generator distribution. Only applicable to CreditCard or TON dataset.', action='store_false')
    parser.add_argument('-gmm', dest='gmm', help='Whether to use GMM for generator distribution. Only applicable to CreditCard or TON dataset.', action='store_true')

    cmd_args = parser.parse_args()
    print(cmd_args)

    dataset = cmd_args.dataset
    Q_dataset = cmd_args.Q_dataset
    N = cmd_args.N
    size = cmd_args.size
    n_trials = cmd_args.n_trials
    not_huber = cmd_args.not_huber
    heterogeneity = cmd_args.heterogeneity
    use_GMM = cmd_args.gmm

    # ---- Per-dataset overrides --------------------------------------------
    # NOTE(review): neither `mmd_batch_size` nor `steps` is read again in this
    # script; confirm that something else consumes them.
    if dataset == 'MNIST':
        options.mmd_batch_size = 256
    elif dataset == 'CIFAR10':
        options.mmd_batch_size = 256
        options.steps = 20000

    print(f"----- Running sample complexity experiment for {baseline} -----")

    set_deterministic()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- Result accumulators ----------------------------------------------
    # "hat" = computed against the union of the vendors' data instead of V_X.
    values_over_trials, values_hat_over_trials = [], []

    sample_size_pcts = [0.001, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    # Loop-invariant, so computed once rather than on every trial.
    # NOTE(review): int(0.001 * size) is 0 whenever size < 1000; confirm that
    # get_MMD_values_uneven copes with a sample size of zero.
    sample_sizes = [int(sample_size_pct * size) for sample_size_pct in sample_size_pcts]

    # One inner list per sub-sampling percentage; each collects one entry per trial.
    values_over_trials_sample_sizes = [[] for _ in sample_size_pcts]
    values_hat_over_trials_sample_sizes = [[] for _ in sample_size_pcts]

    # the baseline of MMD2:
    # when a reference is given, use the reference, but the valuation is based on MMD squared
    values_mmd2_over_trials, values_hat_mmd2_over_trials = [], []
    values_mmd2_over_trials_sample_sizes = [[] for _ in sample_size_pcts]
    values_hat_mmd2_over_trials_sample_sizes = [[] for _ in sample_size_pcts]

    # ---- Trials -----------------------------------------------------------
    for _ in tqdm(range(n_trials), desc =f'A total of {n_trials} trials.'):
        # raw data
        D_Xs, D_Ys, V_X, V_Y, labels = assign_data(N, size, dataset, Q_dataset, not_huber, heterogeneity)

        reference = torch.cat(D_Xs) # i.e., P_N
        netD = None
        print(f"The shape of reference (via union): {reference.shape}.")
        try:
            # No sample-size limit (seventh positional argument is None).
            MMD_values = get_MMD_values_uneven(D_Xs, D_Ys, V_X, None, netD, device, None, batch_size=options.batch_size)
            values_over_trials.append(MMD_values)

            MMD2_values = get_MMD_values_uneven(D_Xs, D_Ys, V_X, None, netD, device, None, squared=True, batch_size=options.batch_size)
            values_mmd2_over_trials.append(MMD2_values)

            MMD_values_hat = get_MMD_values_uneven(D_Xs, D_Ys, reference, None, netD, device, None, batch_size=options.batch_size)
            values_hat_over_trials.append(MMD_values_hat)

            MMD2_values_hat = get_MMD_values_uneven(D_Xs, D_Ys, reference, None, netD, device, None, squared=True, batch_size=options.batch_size)
            values_hat_mmd2_over_trials.append(MMD2_values_hat)

            print(f"Sample sizes are: {sample_sizes}. ")
            # NOTE(review): the two non-squared calls below pass None where the
            # squared calls pass D_Ys; kept as-is, confirm this is intentional.
            for j, sample_size in enumerate(sample_sizes):
                MMD_values = get_MMD_values_uneven(D_Xs, None, V_X, None, netD, device, sample_size, squared=False, batch_size=options.batch_size)
                values_over_trials_sample_sizes[j].append(MMD_values)

                MMD_values_hat = get_MMD_values_uneven(D_Xs, None, reference, None, netD, device, sample_size, squared=False, batch_size=options.batch_size)
                values_hat_over_trials_sample_sizes[j].append(MMD_values_hat)

                MMD2_values_mix_half = get_MMD_values_uneven(D_Xs, D_Ys, V_X, None, netD, device, sample_size, squared=True, batch_size=options.batch_size)
                values_mmd2_over_trials_sample_sizes[j].append(MMD2_values_mix_half)

                MMD2_values_hat_mix_half = get_MMD_values_uneven(D_Xs, D_Ys, reference, None, netD, device, sample_size, squared=True, batch_size=options.batch_size)
                values_hat_mmd2_over_trials_sample_sizes[j].append(MMD2_values_hat_mix_half)

        except RuntimeError as e: # Cuda Memory issue
            if str(e).startswith('CUDA out of memory.'):
                print('CUDA out of memory.')
            # Re-raise the original error so its type, message and traceback
            # are preserved instead of being replaced by a bare Exception.
            raise

    # ---- Persist results --------------------------------------------------
    if not_huber:
        exp_name = oj('not_huber', f'{dataset}_vs_{heterogeneity}-N{N} m{size} n_trials{n_trials}')
    else:
        exp_name = f'{dataset}_vs_{Q_dataset}-N{N} m{size} n_trials{n_trials}'
    exp_name = oj('sample_complexity', exp_name)

    # Metadata stored alongside every result set.
    run_info = {
        'N': N,
        'size': size,
        'n_trials': n_trials,
        'isHuber': not not_huber,
        'heterogeneity': heterogeneity,
        'use_GMM': use_GMM,
    }

    save_results(
        baseline=baseline,
        exp_name=exp_name,
        values_over_trials=values_over_trials,
        values_hat_over_trials=values_hat_over_trials,
        **run_info,
    )

    save_results(
        baseline='MMD_sq_half_mix',
        exp_name=exp_name,
        values_over_trials=values_mmd2_over_trials,
        values_hat_over_trials=values_hat_mmd2_over_trials,
        **run_info,
    )

    # One result set per sub-sampling percentage, non-squared first and then
    # squared (same call order as before).
    for pct, per_size_values, per_size_values_hat in zip(
            sample_size_pcts, values_over_trials_sample_sizes, values_hat_over_trials_sample_sizes):
        save_results(
            baseline=baseline+f'_sample_complexity_{pct}',
            exp_name=exp_name,
            values_over_trials_sample_size=per_size_values,
            values_hat_over_trials_sample_size=per_size_values_hat,
            **run_info,
        )

    for pct, per_size_values, per_size_values_hat in zip(
            sample_size_pcts, values_mmd2_over_trials_sample_sizes, values_hat_mmd2_over_trials_sample_sizes):
        save_results(
            baseline=f'MMD_sq_half_mix_sample_complexity_{pct}',
            exp_name=exp_name,
            values_over_trials_sample_size=per_size_values,
            values_hat_over_trials_sample_size=per_size_values_hat,
            **run_info,
        )
